import pandas as pd 
import pickle
import numpy as np
from helpers import timer 
import argparse
from bow_data import get_data
from sklearn.decomposition import LatentDirichletAllocation
'''
Train an LDA model on the unsupervised split of a dataset
and save the fitted model as a pickle file in the dataset folder.
'''
@timer
def train_lda_unsupervised(datapath="data/data_agnews",lda_dim=4):
    print("Run lda unsupervised...")
    fitted_model_saved_folder = datapath
    vocab, train, test, unsup, valid = get_data(datapath)

    ## use unsupervised set
    unsup_tokens = unsup['tokens']
    unsup_counts = unsup['counts']
    vocab_size=len(vocab)
    ndoc = len(unsup_tokens)
    doc_word_matrix=np.zeros((ndoc,vocab_size))
    for i,(token,count) in enumerate(zip(unsup_tokens,unsup_counts)):
        for t,c in zip(token[0], count[0]):
            doc_word_matrix[i][t]=c

    print("document word matrix has shape: ",doc_word_matrix.shape)
    ## fit model

    print("start training lda...")
    lda = LatentDirichletAllocation(n_components=lda_dim)
    lda.fit(doc_word_matrix)
    print("training finished...")

    #save model
    filename = fitted_model_saved_folder+'/dim_'+str(lda_dim)+'_lda_model.sav'
    pickle.dump(lda, open(filename, 'wb'))
    print(f"model saved at {filename}")

    return lda

if __name__=='__main__':
    # Script entry point. The only statement in the body is an empty string
    # literal, so running this file directly currently trains nothing.
    # The commented-out calls below are example runs with different lda_dim
    # values; uncomment one to train and save a model.
    """"""
    # train_lda_unsupervised(lda_dim=50)
    # train_lda_unsupervised(lda_dim=100)
    # train_lda_unsupervised(lda_dim=500)
    # train_lda_unsupervised(lda_dim=1000)